#pragma once

// #include "db/index/hook.hpp"
#include "db/index/index.hpp"

namespace vectordb {

// Squared-L2 distance entry point using the generic
// (const void*, const void*, const void*) kernel signature of this file.
//
// pVect1v / pVect2v are read as arrays of float and qty_ptr as a pointer to a
// size_t element count. The computation is delegated to
// vectordb::fvec_L2sqr. The disabled branch keeps a plain scalar loop (sum of
// squared element differences) for reference; presumably fvec_L2sqr computes
// the same quantity -- confirm against its definition.
static float
L2Sqr(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
  const float* lhs = static_cast<const float*>(pVect1v);
  const float* rhs = static_cast<const float*>(pVect2v);
  const size_t count = *static_cast<const size_t*>(qty_ptr);
#if 0
  // Scalar reference implementation (compiled out).
  float acc = 0;
  for (size_t i = 0; i < count; ++i) {
    const float delta = lhs[i] - rhs[i];
    acc += delta * delta;
  }
  return acc;
#else
  return vectordb::fvec_L2sqr(lhs, rhs, count);
#endif
}

#if defined(USE_AVX512)

// Favor using AVX512 if available.
//
// Squared-L2 kernel on 512-bit registers: consumes the inputs 16 floats at a
// time using unaligned loads. Only whole groups of 16 are processed; any
// trailing (count % 16) elements are ignored by this function.
//
// pVect1v / pVect2v are read as float arrays, qty_ptr as a pointer to the
// size_t element count. Returns the sum of squared differences over the
// processed prefix.
static float
L2SqrSIMD16ExtAVX512(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
  const float* lhs = static_cast<const float*>(pVect1v);
  const float* rhs = static_cast<const float*>(pVect2v);
  const size_t count = *static_cast<const size_t*>(qty_ptr);

  // End of the last complete 16-float group.
  const float* const lhs_end = lhs + ((count >> 4) << 4);

  __m512 acc = _mm512_set1_ps(0);
  for (; lhs < lhs_end; lhs += 16, rhs += 16) {
    const __m512 delta = _mm512_sub_ps(_mm512_loadu_ps(lhs), _mm512_loadu_ps(rhs));
    // acc = _mm512_fmadd_ps(delta, delta, acc);
    acc = _mm512_add_ps(acc, _mm512_mul_ps(delta, delta));
  }

  // Spill the 16 partial sums to an aligned buffer and add them left to right.
  float PORTABLE_ALIGN64 lanes[16];
  _mm512_store_ps(lanes, acc);

  float total = lanes[0];
  for (int i = 1; i < 16; ++i) {
    total += lanes[i];
  }
  return total;
}
#endif

#if defined(USE_AVX)

// Favor using AVX if available.
//
// Squared-L2 kernel on 256-bit registers. Each loop iteration handles one
// group of 16 floats as two 8-float steps with unaligned loads. Only whole
// groups of 16 are processed; any trailing (count % 16) elements are ignored
// by this function.
//
// pVect1v / pVect2v are read as float arrays, qty_ptr as a pointer to the
// size_t element count. Returns the sum of squared differences over the
// processed prefix.
static float
L2SqrSIMD16ExtAVX(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
  const float* lhs = static_cast<const float*>(pVect1v);
  const float* rhs = static_cast<const float*>(pVect2v);
  const size_t count = *static_cast<const size_t*>(qty_ptr);

  // End of the last complete 16-float group.
  const float* const lhs_end = lhs + ((count >> 4) << 4);

  __m256 acc = _mm256_set1_ps(0);
  for (; lhs < lhs_end; lhs += 16, rhs += 16) {
    // First half of the group: floats [0, 8).
    __m256 delta = _mm256_sub_ps(_mm256_loadu_ps(lhs), _mm256_loadu_ps(rhs));
    acc = _mm256_add_ps(acc, _mm256_mul_ps(delta, delta));

    // Second half of the group: floats [8, 16).
    delta = _mm256_sub_ps(_mm256_loadu_ps(lhs + 8), _mm256_loadu_ps(rhs + 8));
    acc = _mm256_add_ps(acc, _mm256_mul_ps(delta, delta));
  }

  // Spill the 8 partial sums to an aligned buffer and add them left to right.
  float PORTABLE_ALIGN32 lanes[8];
  _mm256_store_ps(lanes, acc);

  float total = lanes[0];
  for (int i = 1; i < 8; ++i) {
    total += lanes[i];
  }
  return total;
}

#endif

#if defined(USE_SSE)

// Squared-L2 kernel on 128-bit SSE registers. Each loop iteration handles one
// group of 16 floats as four 4-float steps with unaligned loads. Only whole
// groups of 16 are processed; any trailing (count % 16) elements are ignored
// by this function.
//
// pVect1v / pVect2v are read as float arrays, qty_ptr as a pointer to the
// size_t element count. Returns the sum of squared differences over the
// processed prefix.
static float
L2SqrSIMD16ExtSSE(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
  const float* lhs = static_cast<const float*>(pVect1v);
  const float* rhs = static_cast<const float*>(pVect2v);
  const size_t count = *static_cast<const size_t*>(qty_ptr);

  // End of the last complete 16-float group.
  const float* const lhs_end = lhs + ((count >> 4) << 4);

  __m128 acc = _mm_set1_ps(0);
  for (; lhs < lhs_end; lhs += 16, rhs += 16) {
    // _mm_prefetch((char*)(rhs + 16), _MM_HINT_T0);

    // Floats [0, 4).
    __m128 delta = _mm_sub_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs));
    acc = _mm_add_ps(acc, _mm_mul_ps(delta, delta));

    // Floats [4, 8).
    delta = _mm_sub_ps(_mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4));
    acc = _mm_add_ps(acc, _mm_mul_ps(delta, delta));

    // Floats [8, 12).
    delta = _mm_sub_ps(_mm_loadu_ps(lhs + 8), _mm_loadu_ps(rhs + 8));
    acc = _mm_add_ps(acc, _mm_mul_ps(delta, delta));

    // Floats [12, 16).
    delta = _mm_sub_ps(_mm_loadu_ps(lhs + 12), _mm_loadu_ps(rhs + 12));
    acc = _mm_add_ps(acc, _mm_mul_ps(delta, delta));
  }

  // Only the first four slots of the buffer are written and summed.
  float PORTABLE_ALIGN32 lanes[8];
  _mm_store_ps(lanes, acc);
  return lanes[0] + lanes[1] + lanes[2] + lanes[3];
}
#endif

#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
// Kernel used for the 16-float-aligned part of a vector. It starts out as the
// SSE implementation; the (currently disabled) dispatch code in L2Space's
// constructor would repoint it to the AVX / AVX-512 variants.
//
// Declared `static` like every kernel in this header: a namespace-scope
// variable with external linkage defined in a header is defined once per
// including translation unit and breaks the link as soon as two of them
// include this file. It is typed with DenseVecDistFunc<float>, the alias the
// space classes in this file store their kernels in, so it can be assigned to
// L2Space::fstdistfunc_ directly.
//
// NOTE(review): the initializer needs L2SqrSIMD16ExtSSE, which only exists
// under USE_SSE -- this assumes USE_AVX / USE_AVX512 are never defined without
// USE_SSE; confirm against the build configuration.
static DenseVecDistFunc<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;

// Squared-L2 distance for element counts that are not a multiple of 16.
//
// The largest multiple-of-16 prefix is handed to L2SqrSIMD16Ext and the
// remaining (count % 16) elements to L2Sqr; the two partial results are added.
// pVect1v / pVect2v are read as float arrays, qty_ptr as a pointer to the
// size_t element count.
static float
L2SqrSIMD16ExtResiduals(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
  const size_t count = *static_cast<const size_t*>(qty_ptr);

  // Number of elements covered by complete 16-float groups, and what is left.
  const size_t head = (count >> 4) << 4;
  const size_t tail = count - head;

  const float head_sum = L2SqrSIMD16Ext(pVect1v, pVect2v, &head);

  const float* lhs_tail = static_cast<const float*>(pVect1v) + head;
  const float* rhs_tail = static_cast<const float*>(pVect2v) + head;
  const float tail_sum = L2Sqr(lhs_tail, rhs_tail, &tail);

  return head_sum + tail_sum;
}
#endif

#if defined(USE_SSE)
// Squared-L2 kernel on 128-bit SSE registers that consumes the inputs 4 floats
// at a time with unaligned loads. Only whole groups of 4 are processed; any
// trailing (count % 4) elements are ignored by this function.
//
// pVect1v / pVect2v are read as float arrays, qty_ptr as a pointer to the
// size_t element count. Returns the sum of squared differences over the
// processed prefix.
static float
L2SqrSIMD4Ext(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
  const float* lhs = static_cast<const float*>(pVect1v);
  const float* rhs = static_cast<const float*>(pVect2v);
  const size_t count = *static_cast<const size_t*>(qty_ptr);

  // End of the last complete 4-float group.
  const float* const lhs_end = lhs + ((count >> 2) << 2);

  __m128 acc = _mm_set1_ps(0);
  for (; lhs < lhs_end; lhs += 4, rhs += 4) {
    const __m128 delta = _mm_sub_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs));
    acc = _mm_add_ps(acc, _mm_mul_ps(delta, delta));
  }

  // Only the first four slots of the buffer are written and summed.
  float PORTABLE_ALIGN32 lanes[8];
  _mm_store_ps(lanes, acc);
  return lanes[0] + lanes[1] + lanes[2] + lanes[3];
}

// Squared-L2 distance for element counts that are not a multiple of 4.
//
// The largest multiple-of-4 prefix is handed to L2SqrSIMD4Ext and the
// remaining (count % 4) elements to L2Sqr; the two partial results are added.
static float
L2SqrSIMD4ExtResiduals(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
  const size_t count = *static_cast<const size_t*>(qty_ptr);

  // Number of elements covered by complete 4-float groups, and what is left.
  const size_t head = (count >> 2) << 2;
  const size_t tail = count - head;

  const float head_sum = L2SqrSIMD4Ext(pVect1v, pVect2v, &head);

  const float* lhs_tail = static_cast<const float*>(pVect1v) + head;
  const float* rhs_tail = static_cast<const float*>(pVect2v) + head;
  const float tail_sum = L2Sqr(lhs_tail, rhs_tail, &tail);

  return head_sum + tail_sum;
}
#endif

// Squared-L2 space over float vectors of a fixed dimension.
//
// Stores the distance kernel, the byte size of one vector
// (dim * sizeof(float)) and the dimension itself, and exposes them through
// the accessors below.
class L2Space : public SpaceInterface<float> {
  DenseVecDistFunc<float> fstdistfunc_;
  size_t data_size_;
  size_t dim_;

 public:
  // Builds the space for vectors of `dim` floats. The kernel is always L2Sqr:
  // the dimension/CPU based selection of the SIMD kernels in the body is
  // compiled out by the enclosing `#if 0`.
  L2Space(size_t dim) : fstdistfunc_(L2Sqr), data_size_(dim * sizeof(float)), dim_(dim) {
#if 0
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_AVX512)
    if (AVX512Capable())
      L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
    else if (AVXCapable())
      L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#elif defined(USE_AVX)
    if (AVXCapable())
      L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#endif

    if (dim % 16 == 0)
      fstdistfunc_ = L2SqrSIMD16Ext;
    else if (dim % 4 == 0)
      fstdistfunc_ = L2SqrSIMD4Ext;
    else if (dim > 16)
      fstdistfunc_ = L2SqrSIMD16ExtResiduals;
    else if (dim > 4)
      fstdistfunc_ = L2SqrSIMD4ExtResiduals;
#endif
#endif
  }

  // Size in bytes of one stored vector.
  size_t
  get_data_size() {
    return data_size_;
  }

  // Distance kernel selected at construction time.
  DenseVecDistFunc<float>
  get_dist_func() {
    return fstdistfunc_;
  }

  // Address of the stored dimension. The kernels in this file read their
  // third argument as a size_t count, so this is presumably what callers pass
  // there -- the call site is not visible in this file.
  void*
  get_dist_func_param() {
    return &dim_;
  }

  ~L2Space() {
  }
};

// Squared-L2 distance between two unsigned-byte vectors, processed in groups
// of four elements.
//
// pVect1 / pVect2 are read as arrays of unsigned char and qty_ptr as a pointer
// to the size_t element count. Only the first (count / 4) * 4 elements
// contribute; a count that is not a multiple of four has its trailing
// elements ignored. Returns the integer sum of squared differences.
static int
L2SqrI4x(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
  const unsigned char* lhs = static_cast<const unsigned char*>(pVect1);
  const unsigned char* rhs = static_cast<const unsigned char*>(pVect2);
  const size_t groups = *static_cast<const size_t*>(qty_ptr) >> 2;

  int total = 0;
  for (size_t g = 0; g < groups; ++g, lhs += 4, rhs += 4) {
    // Bytes are promoted to int before subtracting, so differences may be negative.
    const int d0 = lhs[0] - rhs[0];
    total += d0 * d0;
    const int d1 = lhs[1] - rhs[1];
    total += d1 * d1;
    const int d2 = lhs[2] - rhs[2];
    total += d2 * d2;
    const int d3 = lhs[3] - rhs[3];
    total += d3 * d3;
  }
  return total;
}

// Squared-L2 distance between two unsigned-byte vectors of arbitrary length.
//
// pVect1 / pVect2 are read as arrays of unsigned char and qty_ptr as a pointer
// to the size_t element count. Returns the integer sum of squared
// element-wise differences (0 for an empty input).
static int
L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
  const unsigned char* lhs = static_cast<const unsigned char*>(pVect1);
  const unsigned char* rhs = static_cast<const unsigned char*>(pVect2);
  const size_t count = *static_cast<const size_t*>(qty_ptr);

  int total = 0;
  for (size_t i = 0; i < count; ++i) {
    // Bytes are promoted to int before subtracting, so the difference may be negative.
    const int delta = lhs[i] - rhs[i];
    total += delta * delta;
  }
  return total;
}

// Squared-L2 space over unsigned-byte vectors of a fixed dimension, with
// integer distances.
//
// Stores the distance kernel, the byte size of one vector
// (dim * sizeof(unsigned char)) and the dimension itself, and exposes them
// through the accessors below.
class L2SpaceI : public SpaceInterface<int> {
  DenseVecDistFunc<int> fstdistfunc_;
  size_t data_size_;
  size_t dim_;

 public:
  // Builds the space for vectors of `dim` bytes. Dimensions divisible by four
  // use the four-at-a-time kernel L2SqrI4x; every other dimension uses the
  // element-wise kernel L2SqrI.
  L2SpaceI(size_t dim)
      : fstdistfunc_(dim % 4 == 0 ? L2SqrI4x : L2SqrI), data_size_(dim * sizeof(unsigned char)), dim_(dim) {
  }

  // Size in bytes of one stored vector.
  size_t
  get_data_size() {
    return data_size_;
  }

  // Distance kernel selected at construction time.
  DenseVecDistFunc<int>
  get_dist_func() {
    return fstdistfunc_;
  }

  // Address of the stored dimension. The kernels in this file read their
  // third argument as a size_t count, so this is presumably what callers pass
  // there -- the call site is not visible in this file.
  void*
  get_dist_func_param() {
    return &dim_;
  }

  ~L2SpaceI() {
  }
};

}  // namespace vectordb